# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
# reference: https://github.com/lifeiteng/vall-e
import os
import sys

now_dir = os.getcwd()
sys.path.append(now_dir)
from typing import Dict

import torch
from pytorch_lightning import LightningModule
from gptsovits.AR.models.t2s_model import Text2SemanticDecoder
from gptsovits.AR.modules.optim import ScaledAdam
# NOTE: assumes the LR scheduler is packaged alongside the optimizer, mirroring
# the upstream reference layout (AR/modules/lr_schedulers.py); adjust the path
# if this repo lays it out differently.
from gptsovits.AR.modules.lr_schedulers import WarmupCosineLRSchedule

class Text2SemanticLightningModule(LightningModule):
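    """Lightning wrapper around the stage-1 (text-to-semantic) AR decoder.

    Optionally loads a pretrained s1 checkpoint and trains the decoder with
    manual optimization, stepping the optimizer every 4 batches.
    """
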
    def __init__(self, config, is_train=True):
        super().__init__()
        self.config = config
        self.top_k = 3
        self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
        pretrained_s1 = config.get("pretrained_s1")
        if pretrained_s1 and is_train:
            # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
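            # the s1 checkpoint stores the module weights under the "weight" key;
            # print() surfaces any missing/unexpected keys from load_state_dict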
            print(
                self.load_state_dict(
                    torch.load(pretrained_s1, map_location="cpu")["weight"]
                )
            )
        if is_train:
            # training_step drives the optimizer by hand (manual_backward +
            # explicit opt.step()), so Lightning's automatic optimization
            # must be disabled
            self.automatic_optimization = False
            # self.save_hyperparameters()
            # self.eval_dir = output_dir / "eval"
            # self.eval_dir.mkdir(parents=True, exist_ok=True)

    def training_step(self, batch: Dict, batch_idx: int):
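        # manual optimization: fetch the optimizer and LR scheduler returned by
        # configure_optimizers() and step them explicitly below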
        opt = self.optimizers()
        scheduler = self.lr_schedulers()
        # always use the decoder's standard training forward pass
        forward = self.model.forward_old
        loss, acc = forward(
            batch["phoneme_ids"],
            batch["phoneme_ids_len"],
            batch["semantic_ids"],
            batch["semantic_ids_len"],
            batch["bert_feature"],
        )
        self.manual_backward(loss)
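        # accumulate gradients over 4 batches, then step the optimizer and the
        # LR scheduler together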
        if batch_idx > 0 and batch_idx % 4 == 0:
            opt.step()
            opt.zero_grad()
            scheduler.step()

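        # log loss, current learning rate, and top-k accuracy for the progress
        # bar and any attached loggers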
        self.log(
            "total_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )
        self.log(
            "lr",
            scheduler.get_last_lr()[0],
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )
        self.log(
            f"top_{self.top_k}_acc",
            acc,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )

    def validation_step(self, batch: Dict, batch_idx: int):
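        # no-op: validation is not implemented for this module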
        return

    def configure_optimizers(self):
        model_parameters = self.model.parameters()
        # ScaledAdam takes the parameter names of each param group, which it
        # uses for its gradient-clipping diagnostics
        parameters_names = []
        parameters_names.append(
            [name_param_pair[0] for name_param_pair in self.model.named_parameters()]
        )
        lm_opt = ScaledAdam(
            model_parameters,
            lr=0.01,
            betas=(0.9, 0.95),
            clipping_scale=2.0,
            parameters_names=parameters_names,
            show_dominant_parameters=False,
            clipping_update_period=1000,
        )

        return {
            "optimizer": lm_opt,
            "lr_scheduler": {
                # linear warmup followed by cosine decay; the config keys below
                # follow the upstream reference ("optimizer" section) and may
                # need adjusting to this repo's config schema
                "scheduler": WarmupCosineLRSchedule(
                    lm_opt,
                    init_lr=self.config["optimizer"]["lr_init"],
                    peak_lr=self.config["optimizer"]["lr"],
                    end_lr=self.config["optimizer"]["lr_end"],
                    warmup_steps=self.config["optimizer"]["warmup_steps"],
                    total_steps=self.config["optimizer"]["decay_steps"],
                ),
            },
        }
